{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using torchscript with Classy Vision"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[torchscript](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html) is commonly used to export PyTorch models from Python to C++. This is useful for productionizing models, when you typically perform inference on a CPU. This tutorial will demonstrate how to export a Classy Vision model using `torchscript`'s tracing mode and how to load a torchscript model.\n",
    "\n",
    "## 1. Build and train the model\n",
    "\n",
    "Our [Getting started](https://classyvision.ai/tutorials/getting_started) tutorial covered many ways of training a model, here we'll simply instantiate a ResNeXT model from a config:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from classy_vision.models import build_model\n",
    "import torch\n",
    "\n",
    "config = {\n",
    "    \"name\": \"resnext\",\n",
    "    \"num_blocks\": [3, 4, 23, 3],\n",
    "    \"num_classes\": 1000,\n",
    "    \"base_width_and_cardinality\": [4, 32],\n",
    "    \"small_input\": False,\n",
    "    \"heads\": [\n",
    "        {\n",
    "        \"name\": \"fully_connected\",\n",
    "        \"unique_id\": \"default_head\",\n",
    "        \"num_classes\": 1000,\n",
    "        \"fork_block\": \"block3-2\",\n",
    "        \"in_plane\": 2048\n",
    "        }\n",
    "    ]\n",
    "}\n",
    "\n",
    "model = build_model(config)\n",
    "\n",
    "# set the model in eval mode after training to use for inference\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Export the model\n",
    "\n",
    "Now that the model is built/trained, you can export it using `torch.jit.trace`. To check the results, we'll perform inference on the actual model and on the torchscripted model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    script = torch.jit.trace(model, torch.randn(1, 3, 224, 224, dtype=torch.float))\n",
    "    input = torch.randn(1, 3, 224, 224, dtype=torch.float)\n",
    "    origin_outs = model(input)\n",
    "    script_outs = script(input)\n",
    "\n",
    "assert torch.allclose(origin_outs, script_outs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After verifying the torchscripted model works as expected, you can save it using `torch.jit.save`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.jit.save(script, \"/tmp/resnext_101.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Loading a model\n",
    "\n",
    "Loading a torchscripted model is as simple as calling `torch.jit.load`. If you need to fine-tune or continue training the model, the loaded model can be attached directly to a `ClassificationTask` or `FineTuningTask` in Classy Vision:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_model = torch.jit.load(\"/tmp/resnext_101.pt\")\n",
    "loaded_outs = loaded_model(input)\n",
    "\n",
    "assert torch.allclose(loaded_outs, origin_outs)"
   ]
  },
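  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the exported model will be served on a CPU-only machine, `torch.jit.load` also accepts a `map_location` argument that remaps the saved tensors to the given device. The cell below is a minimal sketch of that, assuming the previous cells have been run so that `/tmp/resnext_101.pt`, `input` and `origin_outs` exist; `cpu_model` and `cpu_outs` are illustrative names only:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only: load the saved model onto the CPU explicitly.\n",
    "# Assumes the earlier cells ran, so the file, `input` and `origin_outs` exist.\n",
    "cpu_model = torch.jit.load(\"/tmp/resnext_101.pt\", map_location=\"cpu\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    cpu_outs = cpu_model(input)\n",
    "\n",
    "assert torch.allclose(cpu_outs, origin_outs)"
   ]
  },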
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Conclusion\n",
    "\n",
    "`torchscript` makes it really easy to transfer models between research and production with PyTorch, and it works seamlessly with Classy Vision. Check out the [torchscript tutorial](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html) for more information about how to export a model correctly. "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
